{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "edcbf738",
   "metadata": {},
   "source": [
    "## Llama 3.2 11B Vision LMI v12 Deployment Guide"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b69ae92d",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to deploy the [Llama3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) model using the LMI v12 container. This example uses the vLLM backend.\n",
    "The current implementation of Llama3.2 vision models in vLLM 0.6.2 does not support CUDA graphs (eager execution is required) and does not support multi-image inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60b26e58",
   "metadata": {},
   "source": [
    "### Install Required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b94d69b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbb8969a",
   "metadata": {},
   "source": [
    "## Create the SageMaker model object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62631e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import image_uris\n",
    "from sagemaker.djl_inference import DJLModel\n",
    "\n",
    "image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.30.0-lmi12.0.0-cu124\"\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# Once the SageMaker Python SDK PR is merged, we can use image_uris directly\n",
    "# image_uri = image_uris.retrieve(framework=\"djl-lmi\", version=\"0.30.0\", region=\"us-west-2\")\n",
    "\n",
    "model = DJLModel(\n",
    "    role=role,\n",
    "    image_uri=image_uri,\n",
    "    env={\n",
    "        \"HF_MODEL_ID\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
    "        \"HF_TOKEN\": \"<huggingface hub token>\",\n",
    "        \"OPTION_ROLLING_BATCH\": \"vllm\",\n",
    "        \"OPTION_MAX_MODEL_LEN\": \"8192\", # this can be tuned depending on instance type + memory available\n",
    "        \"OPTION_MAX_ROLLING_BATCH_SIZE\": \"16\", # this can be tuned depending on instance type + memory available\n",
    "        \"OPTION_TENSOR_PARALLEL_DEGREE\": \"max\",\n",
    "        \"OPTION_ENFORCE_EAGER\": \"true\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57c5e907",
   "metadata": {},
   "source": [
    "## Deploy the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f022cae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = model.deploy(instance_type=\"ml.g6.12xlarge\", initial_instance_count=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a366fe",
   "metadata": {},
   "source": [
    "## Test prompts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b9bc8cb",
   "metadata": {},
   "source": [
    "The following prompts demonstrate how to use the Llama 3.2 11B Vision model for:\n",
    "- Text only inference\n",
    "- Single image inference\n",
    "\n",
    "Multi-image inference is not covered, because the vLLM 0.6.2 implementation of Llama3.2 vision models does not support multi-image inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2260ced",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_1_KITTEN = \"https://resources.djl.ai/images/kitten.jpg\"\n",
    "\n",
    "text_only_payload = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"I would like to get better at basketball. Can you provide me a 3 month plan to improve my skills?\"\n",
    "        }\n",
    "    ],\n",
    "    \"max_tokens\": 1024,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "}\n",
    "\n",
    "single_image_payload = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\n",
    "                    \"type\": \"text\",\n",
    "                    \"text\": \"Can you describe the following image and tell me what it contains?\",\n",
    "                },\n",
    "                {\n",
    "                    \"type\": \"image_url\",\n",
    "                    \"image_url\": {\n",
    "                        \"url\": IMAGE_1_KITTEN\n",
    "                    }\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "    ],\n",
    "    \"max_tokens\": 1024,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2485e48f",
   "metadata": {},
   "source": [
    "### Text Only Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2848eca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Prompt is:\\n {text_only_payload['messages'][0]['content']}\")\n",
    "text_only_output = predictor.predict(text_only_payload)\n",
    "print(\"Response is:\\n\")\n",
    "print(text_only_output['choices'][0]['message']['content'])\n",
    "print('----------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12250682",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "from io import BytesIO\n",
    "\n",
    "# Download the test image so it can be shown next to the model output\n",
    "response_kitten = requests.get(IMAGE_1_KITTEN)\n",
    "response_kitten.raise_for_status()  # fail early if the download did not succeed\n",
    "img_kitten = Image.open(BytesIO(response_kitten.content))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb094d99",
   "metadata": {},
   "source": [
    "### Single Image Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5230292",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"This is the image provided to the model\")\n",
    "display(img_kitten)  # render inline in the notebook; Image.show() needs a local image viewer\n",
    "single_image_output = predictor.predict(single_image_payload)\n",
    "print(single_image_output['choices'][0]['message']['content'])\n",
    "print('----------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f5a68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean up resources\n",
    "predictor.delete_endpoint()\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
